Fix the QwenImage Attention mask under Ulysses SP#13756

Merged
sayakpaul merged 14 commits into
huggingface:mainfrom
zhtmike:fix_qwen_mask
Jun 8, 2026
@zhtmike

@zhtmike zhtmike commented May 15, 2026

Contributor

What does this PR do?

This fixes issue #13696. The test should pass after this PR.

This is the problem I found: the mask does not have a one-to-one correspondence with the content.

For the QwenImage pipeline, consider the following example:

Position:  0  1  2  3  4  5  6  7 | 8  9  10 11
Content:  T0 T1 T2 T3 T4 T5 T6 T7 | I0 I1 I2 I3
Mask:      1  1  0  0  0  0  0  0 | 1  1  1  1

After CP shard (assume 2 ranks)

Rank 0: text=[T0 T1 T2 T3],  image=[I0 I1]  → joint=[T0 T1 T2 T3 I0 I1]
Rank 1: text=[T4 T5 T6 T7],  image=[I2 I3]  → joint=[T4 T5 T6 T7 I2 I3]

After All-to-all

Position:  0  1  2  3 | 4  5  | 6  7  8  9  | 10 11
Content:  T0 T1 T2 T3 | I0 I1 | T4 T5 T6 T7 | I2 I3
           ← rank 0 →          ← rank 1 →

But the mask is not handled correctly and keeps its original layout:

Position:  0  1  2  3  4  5  6  7 | 8  9  10 11
Mask:      1  1  0  0  0  0  0  0 | 1  1  1  1

This PR assigns the mask correctly:

Position:  0  1  2  3 | 4  5 | 6  7  8  9 | 10 11
Mask:      1  1  0  0 | 1  1 | 0  0  0  0 | 1  1
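
The layout mismatch described above can be reproduced with plain Python lists. This is a minimal sketch, not the diffusers implementation: the helper name `shard` is hypothetical, lists stand in for tensors, and the rank-by-rank concatenation is assumed to model what the all-to-all produces along the sequence dimension.

```python
# Illustrative only: plain lists stand in for tensors, and `shard` is a
# hypothetical helper, not a diffusers function.

def shard(seq, world_size):
    """Split a sequence into equal contiguous chunks, one per rank."""
    chunk = len(seq) // world_size
    return [seq[r * chunk:(r + 1) * chunk] for r in range(world_size)]

world_size = 2
text = [f"T{i}" for i in range(8)]
image = [f"I{i}" for i in range(4)]
text_mask = [1, 1, 0, 0, 0, 0, 0, 0]
image_mask = [1, 1, 1, 1]

# Each rank builds its joint sequence as [local text | local image].
text_shards, image_shards = shard(text, world_size), shard(image, world_size)

# Assumed post-all-to-all layout: per-rank joint sequences in rank order.
gathered_content = [
    tok for r in range(world_size) for tok in text_shards[r] + image_shards[r]
]

# Mask built once from the unsharded inputs: [all text | all image].
prebuilt_mask = text_mask + image_mask

# Mask laid out the same way as the content.
tm, im = shard(text_mask, world_size), shard(image_mask, world_size)
aligned_mask = [m for r in range(world_size) for m in tm[r] + im[r]]

print(gathered_content)
print(prebuilt_mask)
print(aligned_mask)
```

In this sketch the pre-built mask hides I0 and I1 (positions 4 and 5) and exposes the padded T6 and T7 (positions 8 and 9), while the aligned mask matches the content position by position.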

Who can review?

@sayakpaul

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions github-actions Bot added models size/S PR with diff < 50 LOC labels May 15, 2026
@zhtmike

zhtmike commented May 15, 2026

Contributor Author

@sayakpaul, I can make this fix more generic by addressing TemplatedUlyssesAttention directly, which should help prevent similar errors in the future.

Here’s my suggestion:

  • Maintain separate local masks for image and text.
  • During all-to-all operations, ensure the mask and the corresponding content always have a one-to-one relationship.
    This way, we avoid numerical inconsistencies caused by mask-related issues.

However, this approach requires handling masks locally in the attention layer (with extra communication cost) and may revert some of the performance improvements introduced in #12702.

Do you have any suggestions?

@sayakpaul

Member

Could you explain a bit why the actual mask is

Position:  0  1  2  3 | 4  5 | 6  7  8  9 | 10 11
Mask:      1  1  0  0 | 1  1 | 0  0  0  0 | 1  1

@zhtmike

zhtmike commented May 15, 2026

Contributor Author

Could you explain a bit why the actual mask is

Position:  0  1  2  3 | 4  5 | 6  7  8  9 | 10 11
Mask:      1  1  0  0 | 1  1 | 0  0  0  0 | 1  1

Since the original input is (T: text token; I: image token):

Position:  0  1  2  3  4  5  6  7 | 8  9  10 11
Content:  T0 T1 T2 T3 T4 T5 T6 T7 | I0 I1 I2 I3
Mask:      1  1  0  0  0  0  0  0 | 1  1  1  1

after all-to-all

Position:  0  1  2  3 | 4  5  | 6  7  8  9  | 10 11
Content:  T0 T1 T2 T3 | I0 I1 | T4 T5 T6 T7 | I2 I3
           ← rank 0 →          ← rank 1 →

So the mask should be positioned to match the content, to avoid attending to the wrong KV entries:

Position:  0  1  2  3 | 4  5  | 6  7  8  9  | 10 11
Content:  T0 T1 T2 T3 | I0 I1 | T4 T5 T6 T7 | I2 I3
Mask:      1  1  0  0 | 1  1  | 0  0  0  0  | 1  1

@sayakpaul

Member

I need @DN6's thoughts on this one. Dhruv, should we tackle it from the forward pass of the model or delegate to the Ulysses utilities we have? Personally, I would prefer doing it from the Ulysses utilities as it would help to keep the forward clean.

@sayakpaul

Member

@claude-2-serge could you do a review?

@github-actions github-actions Bot left a comment

Contributor

🤗 Serge says:

The fix correctly identifies that under Ulysses sequence parallelism, the attention mask must be interleaved to match the post-all-to-all token layout. The approach of chunking text and image masks separately and interleaving them is sound.

Potential Issue

  • _parallel_config not declared on the transformer model class: getattr(self, "_parallel_config", None) at line 919 references self which is the QwenImageTransformer2DModel instance. However, _parallel_config = None is only declared on QwenDoubleStreamAttnProcessor2_0 (line 480), not on the transformer model class itself. Other transformer models (e.g., transformer_flux.py) declare _parallel_config on both the processor and the model class. If the parallel setup code sets _parallel_config on the model via setattr, this works — but it would be more robust and consistent to also declare _parallel_config = None as a class attribute on QwenImageTransformer2DModel.

Style

  • Trailing whitespace on line 926 (else: ).

Tests

  • The PR description references issue #13696 and claims the test should pass, but no new tests are added in this diff. It would be good to confirm the existing test coverage is sufficient.

16 LLM turns · 20 tool calls · 92.9s · 217527 in / 3766 out tokens

Comment thread src/diffusers/models/transformers/transformer_qwenimage.py Outdated
Comment thread src/diffusers/models/transformers/transformer_qwenimage.py Outdated
@DN6

DN6 commented May 18, 2026

Collaborator

I need @DN6's thoughts on this one. Dhruv, should we tackle it from the forward pass of the model or delegate to the Ulysses utilities we have? Personally, I would prefer doing it from the Ulysses utilities as it would help to keep the forward clean.

Yeah, it would be better to handle via the CP utilities and keep the forward as clean as possible.

@zhtmike

zhtmike commented May 19, 2026

Contributor Author

Hi @sayakpaul and @DN6 , I've updated the PR using local mask handling.

This is what happens now:

Inputs (T for text token, I for image token):

Position:  0  1  2  3  4  5  6  7 | 8  9  10 11
Content:  T0 T1 T2 T3 T4 T5 T6 T7 | I0 I1 I2 I3
Mask:      1  1  0  0  0  0  0  0 | 1  1  1  1

After CP shard (assume 2 ranks)

Rank 0: text=[T0 T1 T2 T3],  image=[I0 I1], encoder mask=[1,1,0,0], image mask=[1,1] 
→ joint=[T0 T1 T2 T3 I0 I1], mask=[1,1,0,0,1,1]

Rank 1: text=[T4 T5 T6 T7],  image=[I2 I3], encoder mask=[0,0,0,0], image mask=[1,1]  
→ joint=[T4 T5 T6 T7 I2 I3], mask=[0,0,0,0,1,1]

After QKV All-to-all and mask all-gather

Position:  0  1  2  3 | 4  5  | 6  7  8  9  | 10 11
Content:  T0 T1 T2 T3 | I0 I1 | T4 T5 T6 T7 | I2 I3
Mask:      1  1  0  0 | 1  1  | 0  0  0  0  | 1  1

The KV and mask should now have a one-to-one correspondence.
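
The local-mask flow above can be sketched in a few lines. This is only an illustration under stated assumptions: `local_joint_mask` and `all_gather` are hypothetical names, and the collective is simulated by concatenating each rank's list in rank order rather than by calling a distributed API.

```python
# Illustrative only: lists stand in for tensors and `all_gather` is a local
# simulation of a collective, not a torch.distributed call.

def local_joint_mask(text_mask_shard, num_local_image_tokens):
    """Joint mask for one rank: local text mask followed by ones for image tokens."""
    return list(text_mask_shard) + [1] * num_local_image_tokens

def all_gather(per_rank_masks):
    """Simulated all-gather along the sequence dimension, in rank order."""
    return [m for local in per_rank_masks for m in local]

# Encoder mask after the CP shard (2 ranks), as in the example above.
text_mask_shards = [[1, 1, 0, 0], [0, 0, 0, 0]]

# Each rank holds 2 image tokens and builds its own joint mask.
local_masks = [local_joint_mask(shard, 2) for shard in text_mask_shards]
global_mask = all_gather(local_masks)

print(local_masks)
print(global_mask)
```

Because each rank builds its mask from the same shards it uses for the tokens, the gathered mask follows the gathered KV layout by construction.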

Comment thread src/diffusers/models/transformers/transformer_qwenimage.py Outdated
Comment thread src/diffusers/models/attention_dispatch.py
@github-actions github-actions Bot added size/M PR with diff < 200 LOC and removed size/S PR with diff < 50 LOC labels May 20, 2026
@sayakpaul

Member

@claude-2-serge could you do a review?

@sayakpaul sayakpaul left a comment

Member

Should there be a test?

@github-actions github-actions Bot left a comment

Contributor

🤗 Serge says:

This PR correctly fixes the attention mask handling under Ulysses sequence parallelism for QwenImage by moving joint mask construction from the transformer's forward (where it was built once from unsplit tensors) into the per-block attention processor (where it's built from already-sharded encoder_hidden_states_mask and hidden_states). The attention dispatch layers then all-gather the local mask to match the post-all-to-all QKV layout.

Correctness

  • The core fix is sound: under CP, encoder_hidden_states_mask is now split via the _cp_plan (dim=1), passed into each block, and the processor builds the joint mask from the sharded text mask + sharded image ones. The all-gather in TemplatedUlyssesAttention and TemplatedUlyssesAnythingAttention then reconstructs the global mask in the correct rank-concatenated order matching the QKV layout.
  • The "anything" path correctly handles unequal local sizes with pad → all-gather → trim.
  • The guard attn_mask.shape[-1] == S_KV_LOCAL correctly identifies masks that need gathering (local masks) vs. masks that are already global.
  • F.pad on a bool mask defaults to padding with False, which is correct (padded positions should be masked out before trimming).
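
The pad, all-gather, trim sequence and the local-mask guard from the points above can be illustrated with plain lists. This is a hedged sketch, not the code in attention_dispatch.py: `gather_uneven_masks` and `needs_gather` are hypothetical names, the collective is simulated locally, and it is assumed that every rank already knows the other ranks' true lengths.

```python
# Illustrative only: a local simulation of pad -> all-gather -> trim.

def gather_uneven_masks(local_masks):
    """Combine per-rank masks of unequal length into one global mask."""
    sizes = [len(m) for m in local_masks]
    max_len = max(sizes)
    # Pad with False so every rank contributes a buffer of the same length.
    padded = [m + [False] * (max_len - len(m)) for m in local_masks]
    # Trim each rank's buffer back to its true length, then concatenate
    # in rank order so the result follows the gathered KV layout.
    return [v for buf, n in zip(padded, sizes) for v in buf[:n]]

def needs_gather(mask_len, s_kv_local):
    """A mask whose last dimension equals the local KV length is treated as local."""
    return mask_len == s_kv_local

local_masks = [[True, False, True], [True, True]]
gathered = gather_uneven_masks(local_masks)
print(gathered)
```

The padded False entries never reach the result because they are trimmed before concatenation, which is why the padding value is harmless in this flow.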

Minor issues

  • The type annotation attention_mask: None = None is technically valid but unconventional and could confuse tooling/users. A more standard approach would be to keep the original type hint and rely solely on the runtime ValueError.
  • The mask is reconstructed from scratch in every block (cat + unsqueeze), which adds minor overhead compared to the previous approach of building it once. This is the necessary trade-off for correctness under CP, but worth noting.

Suggestions / additional info (dead code trace)

Under the default pipeline call path, the encoder_hidden_states_mask flows correctly from the transformer forward → block → processor → dispatch_attention_fn. The attention_mask parameter on the processor is now effectively dead (always None from external callers, raises if not), which is the intended design — the processor owns mask construction.

23 LLM turns · 26 tool calls · 150.3s · 485476 in / 5617 out tokens

Comment thread src/diffusers/models/transformers/transformer_qwenimage.py Outdated
Comment thread src/diffusers/models/attention_dispatch.py
zhtmike and others added 3 commits May 20, 2026 11:29
@github-actions github-actions Bot added the tests label May 20, 2026
@zhtmike

zhtmike commented May 20, 2026

Contributor Author

Should there be a test?

Done. I added an accuracy test under ContextParallelTesterMixin. This test should guard against this bug (it should fail on the main branch).
I've also tested it with two other models, flux and flux2, so I think the threshold should be fine.


# Construct joint attention mask once to avoid reconstructing in every block
# This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}

Collaborator

cc @kashif here since this would revert the optimization as part of this PR #12702

I tried to go over #12702, but was not able to find much detail about this optimization. I would love to understand more about the cause of the sync and performance delta, because the pre-built joint mask does not shard correctly under CP

@kashif kashif Jun 3, 2026

Contributor

thanks @yiyixuxu! the #12702 bit was just me building the joint mask once instead of per-block. "eliminates 60 GPU syncs" was a bad comment on my part, i checked and it's actually 0 syncs, just a plain cat/ones. cost of dropping it is ~0.85ms/fwd eager and basically 0 with compile, so no real loss.

and yeah it has to go for CP anyway: the pre-built mask is over the full unsharded seq so it can't line up after the all-to-all. confirmed with the #13696 repro, main is off by 2.9e-2 and this PR is exactly 0.0. lgtm from me 👍

@zhtmike

zhtmike commented Jun 3, 2026

Contributor Author

Hi @sayakpaul @yiyixuxu , may I know if there is any concern for this PR?

@sayakpaul

Member

We're waiting for @kashif for #13756 (comment)

@kashif

kashif commented Jun 3, 2026

Contributor

@zhtmike one heads up: the new correctness test actually passes on main as-is, so it won't catch the regression. if you bump it to batch_size=2 with per-sample padding like in #13696 (e.g. mask[0, :2]=1, mask[1, :6]=1) it fails on main (2.9e-2) and passes here (0.0). not blocking, the fix itself is correct.
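
The per-sample padding pattern suggested above can be written out explicitly. This is a rough sketch with plain lists; the text length of 8 and the 2-rank split are assumptions borrowed from the PR description's example, not values taken from the test suite.

```python
# Illustrative only: lists stand in for a [batch, seq] boolean mask tensor.

text_len, world_size = 8, 2      # assumed sizes, taken from the PR example
valid_lengths = [2, 6]           # mask[0, :2] = 1, mask[1, :6] = 1

# Build one row per sample: ones for valid tokens, zeros for padding.
mask = [[1 if i < n else 0 for i in range(text_len)] for n in valid_lengths]

# Shard every row along the sequence dimension, one chunk per rank.
chunk = text_len // world_size
shards = [
    [row[r * chunk:(r + 1) * chunk] for row in mask] for r in range(world_size)
]

print(mask)
print(shards)
```

With this pattern the two samples have different padding on each rank, so a mask left in its unsharded layout would be wrong in a different way for each sample.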

@zhtmike

zhtmike commented Jun 4, 2026

Contributor Author

Hi @kashif , thank you for your valuable comments!

In my local testing environment, I just used the latest main with two files cherry-picked from the current PR.

  • tests/models/testing_utils/parallelism.py
  • tests/models/transformers/test_models_transformer_qwenimage.py

and ran the following test:

pytest -rxXs tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerContextParallel

And I got the following error:

============================= test session starts ==============================
platform linux -- Python 3.12.12, pytest-9.0.2, pluggy-1.6.0
rootdir: /scratch/fq9hpsac/mikecheung/gitlocal/diffusers
configfile: pyproject.toml
plugins: timeout-2.4.0, xdist-3.8.0, anyio-4.12.1, hydra-core-1.3.2, requests-mock-1.10.0
collected 12 items

tests/models/transformers/test_models_transformer_qwenimage.py .s.s.s.ss [ 75%]
.Fs                                                                      [100%]

=================================== FAILURES ===================================
_ TestQwenImageTransformerContextParallel.test_context_parallel_output_correctness[ulysses] _

self = <tests.models.transformers.test_models_transformer_qwenimage.TestQwenImageTransformerContextParallel object at 0x1550c59ab710>
cp_type = 'ulysses_degree', batch_size = 1

    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
    def test_context_parallel_output_correctness(self, cp_type, batch_size: int = 1):
        """Verify that CP output is numerically identical to a single-GPU reference forward pass."""
        if not torch.distributed.is_available():
            pytest.skip("torch.distributed is not available.")
    
        if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
    
        if cp_type == "ring_degree":
            active_backend, _ = _AttentionBackendRegistry.get_active_backend()
            if active_backend == AttentionBackendName.NATIVE:
                pytest.skip("Ring attention is not supported with the native attention backend.")
    
        world_size = 2
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
    
        # Single-GPU reference
        model = self.model_class(**init_dict).eval().to(torch_device)
        state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
        with torch.no_grad():
            ref_output = model(**inputs_dict, return_dict=False)[0].cpu()
    
        # Context-parallel run with the same weights
        inputs_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
        cp_dict = {cp_type: world_size}
    
        master_port = _find_free_port()
        manager = mp.Manager()
        return_dict = manager.dict()
    
        mp.spawn(
            _context_parallel_correctness_worker,
            args=(world_size, master_port, self.model_class, init_dict, state_dict, cp_dict, inputs_cpu, return_dict),
            nprocs=world_size,
            join=True,
        )
    
        assert return_dict.get("status") == "success", (
            f"Context parallel correctness check failed: {return_dict.get('error', 'Unknown error')}"
        )
    
        cp_output = torch.tensor(return_dict["output"])
>       torch.testing.assert_close(ref_output, cp_output, atol=1e-4, rtol=1e-4)
E       AssertionError: Tensor-likes are not close!
E       
E       Mismatched elements: 251 / 256 (98.0%)
E       Greatest absolute difference: 0.0255483016371727 at index (0, 9, 3) (up to 0.0001 allowed)
E       Greatest relative difference: 1.582349419593811 at index (0, 11, 3) (up to 0.0001 allowed)

tests/models/testing_utils/parallelism.py:461: AssertionError

So, in my testing environment, the newly added test correctly guards against the error in #13696. I also think the error is independent of the batch size, so I prefer to keep it simple with batch_size = 1.

Can you please take a look to see if I have missed something?

@kashif kashif self-requested a review June 4, 2026 08:33
@kashif

kashif commented Jun 4, 2026

Contributor

no you are right! @zhtmike I was not in the main branch 🙈

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zhtmike

zhtmike commented Jun 5, 2026

Contributor Author

We're waiting for @kashif for #13756 (comment)

Kindly ping @sayakpaul

@sayakpaul

sayakpaul commented Jun 5, 2026

Member

@zhtmike

zhtmike commented Jun 5, 2026

Contributor Author

The two failed tests are in tests.models.transformers.test_models_transformer_chronoedit. They seem unrelated to this PR; should we fix them here?

@sayakpaul

Member

@zhtmike

zhtmike commented Jun 5, 2026

Contributor Author

Updated. The error is due to the input check of QwenDoubleStreamAttnProcessor2_0 introduced in this PR.

The QwenImageControlNetModel uses a similar pre-built mask optimization as QwenImageTransformer2DModel. Although it is not SP-supported (so there is no such bug for QwenControlNetPipeline currently), I prefer to make QwenImageControlNetModel follow the same style as QwenImageTransformer2DModel. Please check if this change is suitable.

@sayakpaul sayakpaul left a comment

Member

Thanks for your patience and immense amount of hard work!

Comment thread src/diffusers/models/controlnets/controlnet_qwenimage.py
Comment thread tests/models/testing_utils/parallelism.py Outdated
)

@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_output_correctness(self, cp_type, batch_size: int = 1):

Member

Maybe I didn't make myself clear, but why can't we just add the correctness validation to

def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
?

@zhtmike zhtmike Jun 5, 2026

Contributor Author

Done! Originally, I thought a two-level approach was more suitable, since an accuracy test is a stricter guard than a test that merely checks the model runs. The two have now been merged into a single test.

@kashif

kashif commented Jun 5, 2026

Contributor

do we want to add the same for all CP modes, including ring?

@zhtmike

zhtmike commented Jun 5, 2026

Contributor Author

do we want to add the same for all CP modes, including ring?

Hi @kashif and @sayakpaul, I think a general test case for the precision of CP is still necessary.

Currently, the change of ContextParallelTesterMixin in this PR covers the accuracy tests for Ulysses SP only (ring attention was skipped since ContextParallelTesterMixin only covers the SDPA backend by default, and the SDPA backend does not support ring attention).

So for now, the modifications in tests only cover all Ulysses‑SP‑supported (and test‑guarded) models, including QwenImage, Flux, and Flux2, as they have the ContextParallelTesterMixin test suites. I have run the tests locally, and the newly added tests passed for these three models. So I think the test should be fine.

For ring attention, I think we should add a similar test under ContextParallelAttentionBackendsTesterMixin, but that should be outside the scope of this PR.

May I know if there are any other concerns?

@sayakpaul

Member

For ring attention, I think we should add a similar test under ContextParallelAttentionBackendsTesterMixin, but that should be outside the scope of this PR.

May I know if there are any other concerns?

Yeah good idea. Maybe a future PR?

@zhtmike

zhtmike commented Jun 8, 2026

Contributor Author

For ring attention, I think we should add a similar test under ContextParallelAttentionBackendsTesterMixin, but that should be outside the scope of this PR.

May I know if there are any other concerns?

Yeah good idea. Maybe a future PR?

sure I can help to do this~

Comment on lines +500 to +511
if attention_mask is not None:
    raise ValueError(
        "QwenDoubleStreamAttnProcessor2_0 does not accept an external attention_mask. "
        "Pass encoder_hidden_states_mask to let the processor build the joint mask."
    )

if encoder_hidden_states_mask is not None:
    seq_img = hidden_states.shape[1]
    image_mask = torch.ones((hidden_states.shape[0], seq_img), dtype=torch.bool, device=hidden_states.device)
    attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
    attention_mask = attention_mask[:, None, None, :]

Member

Should the controlnet also have similar changes or not because it doesn't define _cp_plan?

Contributor Author

the controlnet model is calling QwenDoubleStreamAttnProcessor2_0 here

Member

Could you elaborate on that?

@zhtmike zhtmike Jun 8, 2026

Contributor Author

The controlnet model itself does not have a CP plan. However, in this PR we have modified QwenDoubleStreamAttnProcessor2_0, which the controlnet also uses, with an extra guard:

if attention_mask is not None:
    raise ValueError(
        "QwenDoubleStreamAttnProcessor2_0 does not accept an external attention_mask. "
        "Pass encoder_hidden_states_mask to let the processor build the joint mask."
    )

So we need to either:

  • modify the controlnet following the QwenImage transformer’s change; or
  • drop the guard and avoid touching the controlnet.

In the future, we may need to add support for the SP implementation for controlnet. In that case, option 1 may be a better solution, since it will avoid such mask bugs more easily and is also consistent with the style of the Qwen‑Image transformer. So personally, I prefer option 1.

@sayakpaul sayakpaul left a comment

Member

Great work @zhtmike! I will merge after running the tests locally myself :)

@sayakpaul

Member

@bot /style

@sayakpaul

Member

I ran pytest tests/models/transformers/test_models_transformer_qwenimage.py -k "context_parallel_inference" on your branch and main. Both are passing. Is this not the right way to test?

@zhtmike

zhtmike commented Jun 8, 2026

Contributor Author

I ran pytest tests/models/transformers/test_models_transformer_qwenimage.py -k "context_parallel_inference" on your branch and main. Both are passing. Is this not the right way to test?

You might need to pick tests/models/transformers/test_models_transformer_qwenimage.py from the branch. In the main branch the test mask is all true, so it does not reflect the real scenario.
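
A small sketch of why an all-true mask cannot expose the bug: reordering a mask that contains only ones leaves it unchanged, so the pre-built and the correctly laid-out masks coincide. The helper `ulysses_layout` is a hypothetical name, the sizes are assumed from the PR description's example, and lists stand in for tensors.

```python
# Illustrative only: `ulysses_layout` is a hypothetical helper.

def ulysses_layout(text_mask, image_mask, world_size):
    """Mask in the assumed post-all-to-all order: [text shard | image shard] per rank."""
    t = len(text_mask) // world_size
    i = len(image_mask) // world_size
    out = []
    for r in range(world_size):
        out += text_mask[r * t:(r + 1) * t] + image_mask[r * i:(r + 1) * i]
    return out

image_mask = [1, 1, 1, 1]
all_true_text = [1] * 8
padded_text = [1, 1, 0, 0, 0, 0, 0, 0]

# Compare the correctly ordered mask with the pre-built [text | image] mask.
all_true_same = ulysses_layout(all_true_text, image_mask, 2) == all_true_text + image_mask
padded_same = ulysses_layout(padded_text, image_mask, 2) == padded_text + image_mask

print(all_true_same, padded_same)  # True False
```

Only the padded mask distinguishes the two layouts, which is consistent with the updated test file being needed to see the failure on main.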

@sayakpaul sayakpaul merged commit 86dab15 into huggingface:main Jun 8, 2026
14 of 15 checks passed
Labels

models size/M PR with diff < 200 LOC tests

6 participants